# %%
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.nn.conv.gcn_conv import gcn_norm
import torch.nn.init as init
import matplotlib.pyplot as plt
from torch_geometric.nn.conv.gcn_conv import gcn_norm

# %%
def generateAdj(N, p1, p2, c=100, sigma=0.01):
    """Sample a random undirected graph and return its edges as an index tensor.

    Edge probabilities start at ``p1``.  Rows with index >= ``c`` and columns
    with index >= ``c`` each add ``p2 - p1``, so an entry gets ``p2`` when
    exactly one of its indices is >= ``c`` and ``2 * p2 - p1`` when both are.
    Gaussian noise with standard deviation ``sigma`` is then added and the
    result is clamped to ``[0, 1]``.

    Args:
        N: Number of nodes.
        p1: Base edge probability.
        p2: Value used to shift the probabilities of rows/columns >= ``c``.
        c: Index at which the shifted rows/columns start.
        sigma: Standard deviation of the additive noise.

    Returns:
        Tensor of shape ``(2, E)`` listing the non-zero entries of the
        symmetric adjacency matrix (each edge appears in both directions;
        the diagonal is always empty).
    """
    shape = (N, N)

    # Noise is drawn first so the random stream is consumed in a fixed order.
    jitter = torch.normal(mean=0, std=sigma, size=shape)

    # Per-pair edge probabilities.
    shift = p2 - p1
    probs = torch.full(shape, p1)
    probs[c:, :] += shift
    probs[:, c:] += shift
    probs += jitter
    probs = probs.clamp(0, 1)

    # Draw edges on the strict upper triangle only, then mirror them.
    strict_upper = torch.ones(N, N).triu(diagonal=1)
    draws = torch.bernoulli(probs) * strict_upper
    undirected = draws + draws.T.clone()

    return undirected.nonzero(as_tuple=False).t()

def generate_data(A_star, X, W, V, C, alpha):
    """Build regression targets ``H = F + alpha * G(F)`` on a graph.

    With ``F = A* (X W) C`` and ``Z = A* (F V)`` the result is
    ``F + alpha * ((sin(Z) * tanh(Z)) C)``, where ``*`` between ``sin`` and
    ``tanh`` is element-wise.

    Args:
        A_star: (N, N) propagation matrix.
        X: (N, d) node features.
        W: (d, p) projection matrix.
        V: (K, p) matrix mapping the linear targets back to ``p`` dimensions.
        C: (p, K) output matrix.
        alpha: Scalar weight of the non-linear term.

    Returns:
        (N, K) tensor of node targets.
    """
    # Linear component: F = A* X W C  -> (N, K)
    node_labels_F = torch.mm(torch.mm(A_star, torch.mm(X, W)), C)

    # Z feeds both non-linearities, so compute it once.  (The previous
    # version also built tanh(sin(Z)) and then immediately overwrote it.)
    pre_activation = torch.mm(A_star, torch.mm(node_labels_F, V))  # (N, p)
    gated = torch.sin(pre_activation) * torch.tanh(pre_activation)

    # Non-linear component: G(F) = (sin(Z) * tanh(Z)) C  -> (N, K)
    node_labels_G_F = torch.mm(gated, C)

    return node_labels_F + alpha * node_labels_G_F

def split_masks(num_nodes, train_rate, validate_rate, test_rate):
    """Randomly partition ``num_nodes`` nodes into train/validation/test masks.

    Args:
        num_nodes: Total number of nodes.
        train_rate: Fraction of nodes in the training split.
        validate_rate: Fraction of nodes in the validation split.
        test_rate: Fraction of nodes in the test split.  It is only used for
            the sanity check; the test split receives every node that is not
            in the other two splits.

    Returns:
        Tuple ``(train_mask, validate_mask, test_mask)`` of boolean tensors of
        length ``num_nodes``.  The masks are disjoint and cover every node.

    Raises:
        AssertionError: If the three rates do not sum to 1 (within floating
            point tolerance).
    """
    # Compare with a tolerance: sums such as 0.7 + 0.2 + 0.1 are not exactly
    # 1.0 in binary floating point and must not be rejected.
    assert np.isclose(train_rate + validate_rate + test_rate, 1.0), "Rates don't sum up to 1."

    # Random node order (uses NumPy's global random state).
    indices = np.arange(num_nodes)
    np.random.shuffle(indices)

    train_end = int(num_nodes * train_rate)
    validate_end = train_end + int(num_nodes * validate_rate)

    def _mask_for(selected):
        # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the
        # supported spelling of the same dtype.
        mask = np.zeros(num_nodes, dtype=bool)
        mask[selected] = True
        return torch.tensor(mask)

    train_mask = _mask_for(indices[:train_end])
    validate_mask = _mask_for(indices[train_end:validate_end])
    test_mask = _mask_for(indices[validate_end:])

    return train_mask, validate_mask, test_mask

def train(A1, A2):
    """Perform one optimisation step and return the training MSE.

    Uses the module-level ``model``, ``optimizer`` and ``data`` objects.

    Args:
        A1: Graph input handed to the model for its first layer.
        A2: Graph input handed to the model for its second layer.

    Returns:
        The mean-squared error on the training nodes (before the parameter
        update) as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    predictions = model(data.x, A1, A2)
    train_nodes = data.train_mask
    mse = F.mse_loss(predictions[train_nodes], data.y[train_nodes])

    mse.backward()
    optimizer.step()
    return float(mse)

@torch.no_grad()
def test(A1, A2):
    """Evaluate the module-level ``model`` on all three node splits.

    Args:
        A1: Graph input handed to the model for its first layer.
        A2: Graph input handed to the model for its second layer.

    Returns:
        List ``[train_mse, val_mse, test_mse]`` of Python floats.
    """
    model.eval()
    predictions = model(data.x, A1, A2)
    targets = data.y

    splits = (data.train_mask, data.val_mask, data.test_mask)
    return [
        float(F.mse_loss(predictions[split], targets[split]).item())
        for split in splits
    ]

def generate_normalized_adj_matrix(edge_index, num_nodes):
    """Return the dense, symmetrically normalised adjacency ``D^-1/2 (A + I) D^-1/2``.

    Args:
        edge_index: Integer tensor of shape ``(2, E)``; column ``e`` holds the
            (row, column) position of edge ``e``.  Repeated positions
            accumulate, i.e. a duplicated edge contributes weight 2.
        num_nodes: Number of nodes ``N``.

    Returns:
        Dense float tensor of shape ``(N, N)`` on the same device as
        ``edge_index``.  ``D`` is the diagonal matrix of row sums of
        ``A + I``; the self-loops guarantee every row sum is at least 1, so
        the inverse square root is always finite.
    """
    device = edge_index.device

    # torch.sparse.FloatTensor is a deprecated legacy constructor (and
    # CPU-only); torch.sparse_coo_tensor is the supported replacement.
    edge_values = torch.ones(edge_index.size(1), device=device)
    adj_matrix = torch.sparse_coo_tensor(
        edge_index, edge_values, (num_nodes, num_nodes)
    ).to_dense()

    # Add self-loops
    adj_matrix += torch.eye(num_nodes, device=device)

    # Inverse square root of the (row-sum) degrees
    deg_inv_sqrt = 1 / torch.sqrt(adj_matrix.sum(dim=1))

    # Scaling rows and columns is the same as multiplying by the diagonal
    # matrices on both sides, without building them or paying for two dense
    # N x N matrix products.
    return deg_inv_sqrt.unsqueeze(1) * adj_matrix * deg_inv_sqrt.unsqueeze(0)

def generate_edge_index(N=2000, min_degree=1, max_degree=200, mean_degree=200, std_degree=20):
    """Generate a random directed edge list with roughly normal out-degrees.

    Each node's out-degree is drawn from ``Normal(mean_degree, std_degree)``,
    truncated to an integer and clipped to ``[min_degree, max_degree]``.  The
    node is then connected to that many distinct other nodes chosen uniformly
    at random (no self-loops).  Edges are directed: ``(i, j)`` does not imply
    ``(j, i)``.

    Args:
        N: Number of nodes.
        min_degree: Lower bound on a node's out-degree.
        max_degree: Upper bound on a node's out-degree.
        mean_degree: Mean of the normal distribution the degrees come from.
        std_degree: Standard deviation of that distribution.

    Returns:
        ``torch.int64`` tensor of shape ``(2, E)``; row 0 holds source nodes
        (grouped in increasing order), row 1 the sampled neighbours.
    """
    # Generate normal random variables and truncate them to the desired range
    degrees = np.random.normal(loc=mean_degree, scale=std_degree, size=N).astype(int)
    degrees = np.clip(degrees, min_degree, max_degree)

    # A node has only N - 1 possible neighbours.  Without this cap, sampling
    # without replacement raises as soon as a degree exceeds N - 1 (e.g. any
    # N <= 200 with the default degree settings).
    degrees = np.clip(degrees, 0, max(N - 1, 0))

    nodes = np.arange(N)
    neighbor_blocks = [
        # Select 'degree' distinct neighbours among all nodes except `node`
        np.random.choice(np.delete(nodes, node), degree, replace=False)
        for node, degree in enumerate(degrees)
    ]

    # Assemble with NumPy instead of a Python list of hundreds of thousands of
    # NumPy scalars, which torch.tensor converts very slowly.
    rows = np.repeat(nodes, degrees)
    if neighbor_blocks:
        cols = np.concatenate(neighbor_blocks)
    else:
        cols = np.empty(0, dtype=np.int64)

    return torch.tensor(np.stack([rows, cols]), dtype=torch.int64)

def sampling_Astar(edge_index, p_sampling):
    """Keep the fraction ``p_sampling`` of edges with the largest normalised weight.

    Edge weights come from ``gcn_norm(edge_index, add_self_loops=False)``.

    Args:
        edge_index: ``(2, E)`` edge list.
        p_sampling: Fraction of edges to keep; the count is
            ``int(E * p_sampling)``.

    Returns:
        ``(2, K)`` edge list of the kept edges, ordered by decreasing weight.
    """
    edges, weights = gcn_norm(edge_index, add_self_loops=False)
    keep_count = int(weights.numel() * p_sampling)
    strongest = torch.topk(weights, k=keep_count).indices
    return edges[:, strongest]

def sampling(edge_index, p_sampling, p_random=0.05):
    """Keep the highest-weight edges, randomly swapping a few for low-weight ones.

    Edge weights come from ``gcn_norm(edge_index, add_self_loops=False)``.
    The ``K = int(E * p_sampling)`` heaviest edges are selected, then a small
    random subset of them is replaced by the same number of edges whose
    weight is strictly below the smallest selected weight.

    The nominal number of swapped edges is ``int(min(K, E - K) * p_random) + 1``.
    It is capped by the number of available replacement candidates so the
    result always contains exactly ``K`` edges (previously, e.g. with
    ``p_sampling=1.0``, an edge was dropped without a replacement).

    Args:
        edge_index: ``(2, E)`` edge list.
        p_sampling: Fraction of edges to keep.
        p_random: Fraction controlling how many kept edges are swapped.

    Returns:
        ``(2, K)`` edge list: the surviving top edges followed by the
        swapped-in edges.  Empty when ``K`` is 0.
    """
    norm_edge, norm_value = gcn_norm(edge_index, add_self_loops=False)

    num_edges = norm_value.numel()
    num_values_to_keep = min(int(num_edges * p_sampling), num_edges)
    if num_values_to_keep <= 0:
        # Nothing to keep; also avoids indexing the empty top-k result below.
        return norm_edge[:, :0]
    num_values_to_delete = num_edges - num_values_to_keep

    top_values, top_indices = torch.topk(norm_value, k=num_values_to_keep)

    # Candidates for swapping in: everything strictly below the threshold.
    # flatten() (not squeeze()) keeps this 1-D even with a single candidate;
    # squeeze() would yield a 0-d tensor on which len() raises.
    threshold = top_values[-1]
    other_indices = torch.nonzero(norm_value < threshold).flatten()

    # Both original branches differed only in which count they scaled:
    # always the smaller of "kept" and "deleted".
    smaller_side = min(num_values_to_keep, num_values_to_delete)
    num_values_to_random = int(smaller_side * p_random) + 1
    # Never remove more top edges than can actually be replaced.
    num_values_to_random = min(
        num_values_to_random, len(other_indices), len(top_indices)
    )

    # Explicit end index: a negative slice would misbehave for a count of 0
    # (``x[:-0]`` is empty).
    num_top_survivors = len(top_indices) - num_values_to_random
    survivors = torch.randperm(len(top_indices))[:num_top_survivors]
    top_indices = top_indices[survivors]

    newcomers = torch.randperm(len(other_indices))[:num_values_to_random]
    other_indices = other_indices[newcomers]

    choose_indices = torch.cat((top_indices, other_indices))
    return norm_edge[:, choose_indices]


# %%
# Number of nodes in the synthetic graph.
N = 2000
# Pass N explicitly so the graph size always matches the feature matrix and
# masks built from N below, instead of relying on the function's default.
edge_index = generate_edge_index(N=N, mean_degree=200)


# %%
# Dimensions and hyperparameters of the synthetic regression task
d = 100  # Feature dimension (columns of X)
m = 20  # Intermediate dimension (columns of W, rows of C)
k = 5  # Number of output dimensions (columns of y)
alpha = 5  # Weight of the non-linear term in generate_data (adjust as needed)

# Randomly generate X, W, V, and C.
# NOTE: the order of these draws determines the values for a given seed.
X = torch.randn(N, d)
W = torch.randn(d, m)
V = torch.randn(k, m)
C = torch.randn(m, k)


# NOTE(review): edge_weight is not used anywhere else in the visible code.
_, edge_weight = gcn_norm(edge_index)

# Dense normalized adjacency (with self-loops) of the full graph
num_nodes = N
# edge_index_star = sampling_Astar(edge_index, 0.8)
A_star = generate_normalized_adj_matrix(edge_index, num_nodes)
# A_s = delete_small_edges_percentile(A_star, percentile=10)

# Generate the regression targets y with generate_data
y = generate_data(A_star, X, W, V, C, alpha)

# Create a PyTorch Geometric Data object
data = Data(x=X, edge_index=edge_index, y=y)


# %%
class GCN(torch.nn.Module):
    """Two-layer graph network operating on dense propagation matrices.

    Each layer applies a linear map, multiplies by a propagation matrix and
    applies ReLU.  The two hidden representations are summed (skip
    connection) and mapped to ``K`` outputs by a final linear layer whose
    weights are initialised from a standard normal distribution.
    """

    def __init__(self, in_channels, hidden_units, K):
        super().__init__()
        self.layer_1 = nn.Linear(in_channels, hidden_units)
        self.layer_2 = nn.Linear(hidden_units, hidden_units)
        self.out = nn.Linear(hidden_units, K)
        # Overwrite the default initialisation of the output weights with N(0, 1).
        nn.init.normal_(self.out.weight, mean=0.0, std=1.0)

    def forward(self, x, normalized_adj_matrix1, normalized_adj_matrix2):
        """Return the ``(N, K)`` predictions for node features ``x``.

        ``normalized_adj_matrix1`` is used by the first layer and
        ``normalized_adj_matrix2`` by the second.
        """
        first = torch.relu(torch.matmul(normalized_adj_matrix1, self.layer_1(x)))
        second = torch.relu(torch.matmul(normalized_adj_matrix2, self.layer_2(first)))
        return self.out(first + second)

class GCN1(torch.nn.Module):
    """Two ``GCNConv`` layers with a skip connection and a linear read-out.

    The outputs of both convolution layers (after ReLU) are summed and mapped
    to ``out_channels`` values by a linear layer whose weights are
    initialised from a standard normal distribution.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.out_fc = nn.Linear(hidden_channels, out_channels)
        # Overwrite the default initialisation of the read-out weights with N(0, 1).
        nn.init.normal_(self.out_fc.weight, mean=0.0, std=1.0)

    def forward(self, x, edge_index1, edge_index2):
        """Return per-node predictions for features ``x``.

        ``edge_index1`` is the edge list used by the first convolution and
        ``edge_index2`` the one used by the second.
        """
        first = self.conv1(x, edge_index1).relu()
        second = self.conv2(first, edge_index2).relu()
        return self.out_fc(first + second)

# %%
# Fractions of nodes used for training and validation.
train_rate = 0.6

validate_rate = 0.2
# The test split is whatever the other two leave over.  Deriving it from both
# rates (instead of the hard-coded 0.8) keeps the three rates summing to 1
# when validate_rate is changed.
test_rate = 1.0 - train_rate - validate_rate

data.train_mask, data.val_mask, data.test_mask = split_masks(num_nodes, train_rate, validate_rate, test_rate)

# %%
# Use the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# edge_index_t = sampling(edge_index,0.5)
# A_t = generate_normalized_adj_matrix(edge_index_t, num_nodes)
# A_t = A_t.to(device)
# Move the graph data and the dense normalized adjacency to that device.
data = data.to(device)
A_star = A_star.to(device)
# one_norm = torch.norm(A_star, p=1, dim=0).max()
# print(one_norm)

# %%
in_channels = X.shape[1]  # Number of input features per node
hidden_channels = 50 # You can choose a different number based on your requirements
out_channels = y.shape[1]  # Number of regression targets per node (columns of y)


# Alternative model that takes dense normalized adjacency matrices instead of edge lists:
# model = GCN(in_channels, hidden_channels, out_channels).to(device)
model = GCN1(in_channels, hidden_channels, out_channels).to(device)
lr = 1e-3 # learning rate
# Adam with a small L2 penalty (weight_decay) on all parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-5)
 
# def run_experiment(sampling_rate, configuration, num_epochs=3001):
#     edge_index_t = sampling(edge_index, sampling_rate)
#     A_t = generate_normalized_adj_matrix(edge_index_t, num_nodes)
#     A_t = A_t.to(device)
    
#     if configuration == "shallow":
#         A1 = A_t
#         A2 = A_star
#     elif configuration == "deep":
#         A2 = A_t
#         A1 = A_star
#     elif configuration == "both":
#         A2 = A_t
#         A1 = A_t
#     else:
#         raise ValueError("Invalid configuration")
    
#     train_losses = []
#     val_losses = []
#     test_losses = []
#     min_val_loss = float('inf')
#     best_test_loss = None

#     for epoch in range(1, num_epochs):
#         train_loss = train(A1, A2)
#         train_loss, val_loss, test_loss = test(A1, A2)
#         train_losses.append(train_loss)
#         val_losses.append(val_loss)
#         test_losses.append(test_loss)
#         if val_loss < min_val_loss:
#             min_val_loss = val_loss
#             best_test_loss = test_loss

#     return best_test_loss

def run_experiment(sampling_rate1, sampling_rate2, num_epochs=1250, num_runs=1):
    """Train with per-epoch edge sampling and report the selected test loss.

    In every epoch two edge subsets are drawn with ``sampling`` (one per
    layer), the model takes one training step on them and is then evaluated
    on the same subsets.  For each run the test loss at the epoch with the
    lowest validation loss is recorded.

    Every run starts from a freshly initialised model and optimizer.  The
    module-level ``model`` and ``optimizer`` are rebound to the new objects
    (``train`` and ``test`` read them from module scope).  Without this
    reset each run, and each call for a new pair of sampling rates, would
    continue training the weights left by the previous one, so results for
    different rates would not be comparable.

    Args:
        sampling_rate1: Edge sampling rate for the first layer.
        sampling_rate2: Edge sampling rate for the second layer.
        num_epochs: Loop bound; ``num_epochs - 1`` training epochs are run.
        num_runs: Number of independent runs to average over.

    Returns:
        Mean over runs of the test loss at the best-validation epoch.

    Raises:
        ValueError: If ``num_runs < 1`` or ``num_epochs < 2`` (no epoch would
            run, so there would be nothing to report).
    """
    global model, optimizer

    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")
    if num_epochs < 2:
        raise ValueError("num_epochs must be at least 2")

    best_test_losses = []
    print(f"Sampling rate: {sampling_rate1}, {sampling_rate2}")
    for _run in range(num_runs):
        # Fresh weights and optimizer state, same architecture and
        # hyperparameters as the module-level setup.
        model = type(model)(in_channels, hidden_channels, out_channels).to(device)
        optimizer = type(optimizer)(
            model.parameters(),
            lr=optimizer.defaults["lr"],
            weight_decay=optimizer.defaults["weight_decay"],
        )

        min_val_loss = float('inf')
        best_test_loss = None

        # range(1, num_epochs) is kept from the original: num_epochs - 1 epochs.
        for _epoch in range(1, num_epochs):
            # Re-sample the edges used by each layer in every epoch.
            edge_index_t1 = sampling(data.edge_index, sampling_rate1)
            edge_index_t2 = sampling(data.edge_index, sampling_rate2)

            train(edge_index_t1, edge_index_t2)
            _, val_loss, test_loss = test(edge_index_t1, edge_index_t2)

            # Model selection on the validation split.
            if val_loss < min_val_loss:
                min_val_loss = val_loss
                best_test_loss = test_loss

        best_test_losses.append(best_test_loss)

    avg_best_test_loss = sum(best_test_losses) / num_runs

    return avg_best_test_loss


import os

# Results keyed by (rate1, rate2): average best test loss for that pair.
avg_best_losses = {}
avg_one_norms = {}
sampling_rates = np.arange(0.1, 1.01, 0.1)

results_path = 'figure_A_S/2.avg_best_losses_2d.npy'
# Create the output directory before the (long) sweep starts; otherwise a
# missing folder would only surface in np.save after all training is done
# and the results would be lost.
os.makedirs(os.path.dirname(results_path), exist_ok=True)

for rate1 in sampling_rates:
    # Round away the accumulated floating point error of np.arange so the
    # dictionary keys are clean values such as 0.3 rather than 0.30000000000000004.
    rate1 = round(rate1, 1)
    for rate2 in sampling_rates:
        rate2 = round(rate2, 1)
        avg_best_loss = run_experiment(rate1, rate2)
        avg_best_losses[(rate1, rate2)] = avg_best_loss

# Largest absolute column sum of A_star (matrix 1-norm), printed for reference.
one_norm = torch.norm(A_star, p=1, dim=0).max()
print(one_norm)

# Saving the results to a numpy file.  np.save pickles the dict into a 0-d
# object array; read it back with np.load(path, allow_pickle=True).item().
np.save(results_path, avg_best_losses)

# markers = ['o', 's', 'D'] # Markers for three lines
# colors = ['b', 'g', 'r'] # Colors for three lines
# linewidths = 5 # Linewidths for three lines
# fontsize = 20

# plt.figure(figsize=(8, 6), dpi=300)
# # Loop through the three configurations
# for index, config in enumerate(configurations):
#     losses = [avg_best_losses[(config, rate)] for rate in sampling_rates]
#     plt.plot(sampling_rates, np.log10(losses), marker=markers[index], markersize=fontsize,
#              color=colors[index], linewidth = linewidths, label=f'sampling {config}')

# plt.grid(which='both', linestyle='--', linewidth=0.5)
# plt.xlabel(r'$Sampling rate$', fontsize=fontsize)
# plt.ylabel('Test error', fontsize=fontsize)
# plt.legend(fontsize=15)
# plt.xticks(fontsize=fontsize) # Customize the x-axis tick labels
# plt.yticks(fontsize=fontsize)
# plt.show()
# plt.savefig('figure_A_S/final_test_loss_plot.png', dpi=300)
